import torch
from torch import nn
import pickle
from parser1 import args
from function import *
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Multi-head attention between candidate items and a user's interacted items.

    Two query sources share the same projection weights per modality:
      * ``compute_cross`` — queries are the recalled candidate items
        (``recall_item[train_user]``).
      * ``compute_self`` — queries are the positive items (``pos_item``).

    Both attend over the user's padded interaction history and both are
    offered for the "image" and "text" modalities.
    """

    def __init__(self, embed_size, num_heads, recall_item):
        super(CrossAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_size = embed_size // num_heads
        # Pre-extracted per-item modality features, one row per item.
        # NOTE(review): hard-pinned to CUDA — presumably a GPU is always available.
        self.image_feats = torch.from_numpy(np.load(args.data_path + '{}/image_feat.npy'.format(args.dataset))).to(dtype=torch.float32).cuda()
        self.text_feats = torch.from_numpy(np.load(args.data_path + '{}/text_feat.npy'.format(args.dataset))).cuda().to(dtype=torch.float32)
        self.img_size = self.image_feats.shape[1]
        self.txt_size = self.text_feats.shape[1]
        self.len_item = self.image_feats.shape[0]
        self.batch_size = args.batch_size
        self.recall_item = recall_item
        # Per-modality Q/K/V projections plus an output projection.
        self.i_WQ = nn.Linear(self.img_size, embed_size, bias=False)
        self.i_WK = nn.Linear(self.img_size, embed_size, bias=False)
        self.i_WV = nn.Linear(self.img_size, embed_size, bias=False)
        self.i_WO = nn.Linear(self.embed_size, embed_size, bias=False)

        self.t_WQ = nn.Linear(self.txt_size, embed_size, bias=False)
        self.t_WK = nn.Linear(self.txt_size, embed_size, bias=False)
        self.t_WV = nn.Linear(self.txt_size, embed_size, bias=False)
        self.t_WO = nn.Linear(self.embed_size, embed_size, bias=False)

    def _attend(self, type, query_index, train_items, max_inter, n_user):
        """Shared multi-head attention core for both query sources.

        ``query_index`` indexes the query items into the modality feature
        table; ``train_items`` is a per-user list of interacted item ids,
        zero-padded up to ``max_inter`` key/value slots.
        """
        if type == "image":
            feats, WQ, WK, WV, WO = self.image_feats, self.i_WQ, self.i_WK, self.i_WV, self.i_WO
        elif type == "text":
            feats, WQ, WK, WV, WO = self.text_feats, self.t_WQ, self.t_WK, self.t_WV, self.t_WO
        else:
            # Previously an unknown modality crashed later with a NameError;
            # fail fast with a clear message instead.
            raise ValueError("type must be 'image' or 'text', got {!r}".format(type))

        # Gather each user's interacted-item features, zero-padded to max_inter rows.
        intered_item = torch.full((n_user, max_inter, feats.shape[1]), 0.).cuda()
        for i, items in enumerate(train_items):
            intered_item[i, :len(items), :] = feats[items]

        # Project and split into heads.
        Q = WQ(feats[query_index]).view(n_user, -1, self.num_heads, self.head_size)
        K = WK(intered_item).view(n_user, max_inter, self.num_heads, self.head_size)
        V = WV(intered_item).view(n_user, max_inter, self.num_heads, self.head_size)

        # (num_heads, n_user, seq_len, head_size)
        Q = Q.permute(2, 0, 1, 3)
        K = K.permute(2, 0, 1, 3)
        V = V.permute(2, 0, 1, 3)

        # Scaled dot-product attention scores.
        att_scores = torch.matmul(Q.to(dtype=torch.float), K.permute(0, 1, 3, 2)) / torch.sqrt(
            torch.tensor(self.head_size).float())
        # Zero-padded key rows produce exactly-zero scores; mask them to -inf
        # so softmax ignores the padding. NOTE(review): a genuinely zero score
        # on a real item would be masked too — fragile, but kept as designed.
        att_scores = torch.where(att_scores == 0., float('-inf'), att_scores)
        att_weights = F.softmax(att_scores, dim=-1)
        # Guard against NaNs before the weighted sum.
        V = torch.where(torch.isnan(V), 0, V)
        att_output = torch.matmul(att_weights, V).permute(1, 2, 0, 3).contiguous().view(n_user, -1, self.embed_size)
        # Final output projection for this modality.
        return WO(att_output)

    def compute_cross(self, type, train_user, train_items, max_inter):
        """Attend the recalled candidate items over each user's history."""
        return self._attend(type, self.recall_item[train_user], train_items, max_inter, len(train_user))

    def compute_self(self, type, train_user, train_items, max_inter, pos_item):
        """Attend the positive items over each user's history."""
        return self._attend(type, pos_item, train_items, max_inter, len(train_user))

    def forward(self, train_user, train_items, max_inter, pos_item, is_test):
        """Return (cross-attended user features, self-attended positive
        features or 0 at test time, concatenated item V-projections)."""
        cross_image = self.compute_cross('image', train_user, train_items, max_inter)
        cross_text = self.compute_cross('text', train_user, train_items, max_inter)
        if is_test:
            self_att = 0
        else:
            self_image = self.compute_self('image', train_user, train_items, max_inter, pos_item)
            self_text = self.compute_self('text', train_user, train_items, max_inter, pos_item)
            self_att = torch.cat([self_image, self_text], -1)
        # Raw item embeddings through the V projections of both modalities.
        i_v = self.i_WV(self.image_feats)
        t_v = self.t_WV(self.text_feats)
        return torch.cat([cross_image, cross_text], -1), self_att, torch.cat([i_v, t_v], -1)


class Generater(nn.Module):
    """Generator: wraps :class:`CrossAttention` and precomputes the
    symmetrically-normalized user-item adjacency from the training matrix.

    ``user_emb`` and ``item_emb`` are accepted for interface compatibility
    but are currently unused.
    """

    def __init__(self, emb_size, user_emb, item_emb, recall_item):
        super(Generater, self).__init__()
        self.n_layers = 2
        self.emb_size = emb_size
        # Fix: the original `pickle.load(open(...))` leaked the file handle.
        # SECURITY NOTE(review): pickle is only safe on trusted data files.
        with open(args.data_path + args.dataset + '/train_mat', 'rb') as f:
            self.ui_graph = pickle.load(f)
        self.n_user = self.ui_graph.shape[0]
        self.n_item = self.ui_graph.shape[1]
        # Bipartite adjacency A = [[0, R], [R^T, 0]] over users + items.
        A = sp.dok_matrix((self.n_user + self.n_item, self.n_user + self.n_item), dtype=np.float32)
        A = A.tolil()
        R = self.ui_graph.todok()
        A[:self.n_user, self.n_user:] = R
        A[self.n_user:, :self.n_user] = R.T
        # Symmetric degree normalization: L = D^{-1/2} A D^{-1/2}.
        sumArr = (A > 0).sum(axis=1)
        diag = np.array(sumArr.flatten())[0] + 1e-7  # epsilon guards isolated nodes
        diag = np.power(diag, -0.5)
        D = sp.diags(diag)
        L = D * A * D
        # Kept as COO for potential sparse propagation (currently unused here).
        self.L = sp.coo_matrix(L)
        self.cross_attention = CrossAttention(self.emb_size, args.head_num, recall_item)

    def forward(self, train_user, train_items, max_inter, pos_item, is_test=False):
        """Delegate to the cross-attention module.

        Returns (negative/candidate user features, positive user features,
        concatenated item features).
        """
        neg_user_feature, pos_user_feature, item_feature = self.cross_attention(
            train_user, train_items, max_inter, pos_item, is_test)
        return neg_user_feature, pos_user_feature, item_feature


class Discriminator(nn.Module):
    """MLP discriminator: four halving hidden layers, sigmoid output.

    Each hidden stage is Linear -> BatchNorm1d -> Dropout(0.3) -> LeakyReLU,
    with widths dim -> dim/2 -> dim/4 -> dim/8 -> dim/16, followed by a
    final Linear to a single sigmoid-activated score.
    """

    def __init__(self, dim):
        super(Discriminator, self).__init__()
        # Halving width schedule: [dim, dim/2, dim/4, dim/8, dim/16].
        widths = [dim] + [round(dim / 2 ** k) for k in range(1, 5)]
        stages = []
        for w_in, w_out in zip(widths, widths[1:]):
            stages.append(torch.nn.Linear(in_features=w_in, out_features=w_out))
            stages.append(torch.nn.BatchNorm1d(w_out))
            stages.append(torch.nn.Dropout(0.3))
            stages.append(torch.nn.LeakyReLU())
        # Output head: scalar probability-like score in (0, 1).
        stages.append(torch.nn.Linear(widths[-1], out_features=1))
        stages.append(torch.nn.Sigmoid())
        self.fc = torch.nn.Sequential(*stages)

    def forward(self, input):
        x = self.fc(input)
        return x



